{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "0fb7b40add6e49e1b82b4a60242c7d1e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_d4d21438d056453fb8e5ccd8bb568246",
              "IPY_MODEL_e88fb8bc2fb94cf2ac582f73e9396be3",
              "IPY_MODEL_5f64fedf36cc4872a09b9bf2812a6184"
            ],
            "layout": "IPY_MODEL_a47a6b81157d4f0ca4baca71fc276465"
          }
        },
        "d4d21438d056453fb8e5ccd8bb568246": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_be06b0f685ef458b8f9b946d1170e478",
            "placeholder": "​",
            "style": "IPY_MODEL_f87b4a192c4144a8a5eb8e793ffda51e",
            "value": "100%"
          }
        },
        "e88fb8bc2fb94cf2ac582f73e9396be3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d4e564fcb77848d983f7db5b12ccc572",
            "max": 20,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_6e9626d1f84a464886d0d319ec474186",
            "value": 20
          }
        },
        "5f64fedf36cc4872a09b9bf2812a6184": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d1f7cb270d0f4959b1876eaa88cf8736",
            "placeholder": "​",
            "style": "IPY_MODEL_042cd49ddb3049b8b6681f461c836221",
            "value": " 20/20 [00:02&lt;00:00,  7.92it/s]"
          }
        },
        "a47a6b81157d4f0ca4baca71fc276465": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "be06b0f685ef458b8f9b946d1170e478": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f87b4a192c4144a8a5eb8e793ffda51e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "d4e564fcb77848d983f7db5b12ccc572": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6e9626d1f84a464886d0d319ec474186": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "d1f7cb270d0f4959b1876eaa88cf8736": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "042cd49ddb3049b8b6681f461c836221": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "<b>Use Gemini to prompt Stable Diffusion</b>\n",
        "\n",
        "Finding the right words to prompt an image generator can be a chore. Use Google's latest model release, Gemini, to prompt Stable Diffusion to produce amazing generated imagery.\n",
        "\n",
        "This notebook includes modified code from [woctezuma/stable-diffusion-colab](https://github.com/woctezuma/stable-diffusion-colab). Thanks."
      ],
      "metadata": {
        "id": "ZWhWniBGu3_Y"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Check for a GPU\n",
        "\n",
        "import os\n",
        "\n",
        "# os.system returns a non-zero status when nvidia-smi is missing or fails;\n",
        "# treat that as \"no usable GPU\" and stop the notebook before the heavy cells run.\n",
        "if os.system('nvidia-smi') != 0:\n",
        "  raise RuntimeError(\"No GPU found. Access a GPU through Runtime > Change runtime type and try again.\")"
      ],
      "metadata": {
        "cellView": "form",
        "id": "Opn0sFIkuoVw"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Configure Gemini API key\n",
        "\n",
        "import google.generativeai as genai\n",
        "from google.colab import userdata\n",
        "\n",
        "gemini_api_secret_name = 'GOOGLE_API_KEY'  # @param {type: \"string\"}\n",
        "\n",
        "# Read the API key from Colab's secret store and hand it to the Gemini client.\n",
        "# Every failure path prints a hint and then re-raises the original exception,\n",
        "# so the cell stops instead of continuing without a configured key.\n",
        "try:\n",
        "  GOOGLE_API_KEY = userdata.get(gemini_api_secret_name)\n",
        "  genai.configure(api_key=GOOGLE_API_KEY)\n",
        "except userdata.SecretNotFoundError:\n",
        "  print(f'''Secret not found\\n\\nThis expects you to create a secret named {gemini_api_secret_name} in Colab\\n\\nVisit https://makersuite.google.com/app/apikey to create an API key\\n\\nStore that in the secrets section on the left side of the notebook (key icon)\\n\\nName the secret {gemini_api_secret_name}''')\n",
        "  raise\n",
        "except userdata.NotebookAccessError:\n",
        "  print(f'''You need to grant this notebook access to the {gemini_api_secret_name} secret in order for the notebook to access Gemini on your behalf.''')\n",
        "  raise\n",
        "except Exception:\n",
        "  # unknown error\n",
        "  print(f\"There was an unknown error. Ensure you have a secret {gemini_api_secret_name} stored in Colab and it's a valid key from https://makersuite.google.com/app/apikey\")\n",
        "  raise\n",
        "\n",
        "# Shared Gemini model handle used by the prompt-generation cell.\n",
        "model = genai.GenerativeModel('gemini-pro')"
      ],
      "metadata": {
        "id": "yFv1abRcv2P2",
        "cellView": "form"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "CypmzyrIrc9k"
      },
      "outputs": [],
      "source": [
        "#@title Setup dependencies & pipeline - this takes a bit\n",
        "\n",
        "%pip install --quiet --upgrade diffusers accelerate mediapy\n",
        "\n",
        "import random\n",
        "import sys\n",
        "\n",
        "import mediapy as media\n",
        "import torch\n",
        "from diffusers import AutoPipelineForText2Image\n",
        "\n",
        "# Load the fp16 safetensors variant of SDXL-Turbo and move the whole\n",
        "# pipeline onto the GPU in one expression.\n",
        "pipe = AutoPipelineForText2Image.from_pretrained(\n",
        "    \"stabilityai/sdxl-turbo\",\n",
        "    torch_dtype=torch.float16,\n",
        "    use_safetensors=True,\n",
        "    variant=\"fp16\",\n",
        ").to(\"cuda\")"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Use Gemini to create a text prompt to feed to Stable Diffusion\n",
        "\n",
        "text = 'Draw a kitty cat typing on a computer' # @param {type:\"string\"}\n",
        "\n",
        "# Ask Gemini (the `model` created in the API-key cell) to expand the short\n",
        "# description into an image prompt. The instruction gets its own name so it is\n",
        "# not confused with the Stable Diffusion `prompt` built in the next cell.\n",
        "gemini_prompt = f\"You are creating a prompt for Stable Diffusion to generate an image. Please generate a text prompt for {text}. Only respond with the prompt itself, but embellish it as needed but keep it under 80 tokens.\"\n",
        "response = model.generate_content(gemini_prompt)\n",
        "\n",
        "# Show the generated prompt; the next cell reads it from `response.text`.\n",
        "response.text"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 53
        },
        "id": "GfBJHe1NusKt",
        "outputId": "633810e8-295c-4043-c111-5de01782ed41",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\"[A small calico kitten with dark brown stripes sits on a computer chair, typing on a keyboard with a serious expression. The kitten's tail flicks back and forth in excitement as it sends an email. The background shows a cluttered home office with bookshelves and various office supplies. Draw in a cute anime style.]\""
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 14
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Generate the image with Stable Diffusion\n",
        "\n",
        "# Use the text Gemini produced in the previous cell as the image prompt.\n",
        "prompt = response.text\n",
        "\n",
        "# Draw a fresh seed on every run and print it, so a particular image can be\n",
        "# regenerated later by seeding the generator with the same value.\n",
        "seed = random.randint(0, sys.maxsize)\n",
        "print(f\"seed: {seed}\")\n",
        "\n",
        "num_inference_steps = 20\n",
        "\n",
        "# guidance_scale is 0.0 because SDXL-Turbo is meant to run without\n",
        "# classifier-free guidance (see the model card).\n",
        "images = pipe(\n",
        "    prompt = prompt,\n",
        "    guidance_scale = 0.0,\n",
        "    num_inference_steps = num_inference_steps,\n",
        "    generator = torch.Generator(\"cuda\").manual_seed(seed),\n",
        "    ).images\n",
        "\n",
        "media.show_images(images)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 567,
          "referenced_widgets": [
            "0fb7b40add6e49e1b82b4a60242c7d1e",
            "d4d21438d056453fb8e5ccd8bb568246",
            "e88fb8bc2fb94cf2ac582f73e9396be3",
            "5f64fedf36cc4872a09b9bf2812a6184",
            "a47a6b81157d4f0ca4baca71fc276465",
            "be06b0f685ef458b8f9b946d1170e478",
            "f87b4a192c4144a8a5eb8e793ffda51e",
            "d4e564fcb77848d983f7db5b12ccc572",
            "6e9626d1f84a464886d0d319ec474186",
            "d1f7cb270d0f4959b1876eaa88cf8736",
            "042cd49ddb3049b8b6681f461c836221"
          ]
        },
        "id": "25Aon1_AvsKw",
        "outputId": "d96d0f12-6298-418b-efb1-dabb8d5cee2b",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "  0%|          | 0/20 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "0fb7b40add6e49e1b82b4a60242c7d1e"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<table class=\"show_images\" style=\"border-spacing:0px;\"><tr><td style=\"padding:1px;\"><img width=\"512\" height=\"512\" style=\"image-rendering:auto; object-fit:cover;\" src=\"\"/></td></tr></table>"
            ]
          },
          "metadata": {}
        }
      ]
    }
  ]
}